import argparse
import cPickle as pkl
import numpy
import theano

from collections import OrderedDict
from theano import tensor

import lm

# Forwarded as the `profile` argument of theano.function when f_next is
# compiled in build_sampler.
profile = False


# load parameters
def load_params(path):
    """Load every array stored in a numpy archive into an ordered dict.

    Args:
        path: location of the archive, passed straight to ``numpy.load``.

    Returns:
        An OrderedDict mapping each array name to its array, in the order
        the archive lists them.
    """
    archive = numpy.load(path)
    try:
        # Indexing reads each array into memory, so the result stays valid
        # after the archive (and its file handle) is closed below.  Going
        # through ``archive.files`` avoids the Python-2-only ``iteritems``.
        return OrderedDict((name, archive[name]) for name in archive.files)
    finally:
        archive.close()


# Loads a pickled dictionary
def load_dictionary(filename):
    """Return the object unpickled from the file at ``filename``.

    The file is opened in binary mode and closed before returning.
    NOTE: unpickling can execute arbitrary code, so only load trusted files.
    """
    handle = open(filename, 'rb')
    try:
        return pkl.load(handle)
    finally:
        handle.close()


# Inverts a dictionary and ensures special tokens
def invert_dictionary(word_dict):
    """Build the reverse (index -> word) mapping of ``word_dict``.

    Args:
        word_dict: mapping from word to index.

    Returns:
        A dict from index to word.  Indices 0 and 1 are always set to
        '<eos>' and 'UNK', overriding whatever ``word_dict`` assigned there.
    """
    # ``items`` works on both Python 2 and 3, unlike ``iteritems``.
    word_idict = {index: word for word, index in word_dict.items()}
    word_idict[0] = '<eos>'
    word_idict[1] = 'UNK'
    return word_idict


# initialize Theano shared variables according to the initial parameters
def init_tparams(params):
    """Wrap every entry of ``params`` in a Theano shared variable.

    Returns an OrderedDict with the same keys, in the same order, where each
    value is ``theano.shared`` of the original value and is named after its
    key.
    """
    return OrderedDict(
        (name, theano.shared(params[name], name=name)) for name in params)


# Layer registry: layer name -> (parameter-initializer expression,
# feedforward expression).  Both entries are strings that get_layer()
# evaluates to obtain the actual callables.
layers = dict(
    ff=('lm.param_init_fflayer', 'lm.fflayer'),
    gru=('lm.param_init_gru', 'lm.gru_layer'),
)


# Utility function to get layer props
def get_layer(name):
    """Look up ``name`` in ``layers`` and evaluate its two entries.

    Returns a 2-tuple: the evaluated parameter-initializer expression and the
    evaluated feedforward expression.  Raises KeyError for an unknown name.
    """
    spec = layers[name]
    # The registry stores expression strings, so they are resolved with eval;
    # only the hard-coded strings in ``layers`` ever reach this point.
    param_init = eval(spec[0])
    feedforward = eval(spec[1])
    return (param_init, feedforward)


# build a sampler
def build_sampler(tparams, options):
    # x: 1 x 1
    y = tensor.vector('y_sampler', dtype='int64')
    init_state = tensor.matrix('init_state', dtype='float32')

    # if it's the first word, emb should be all zero
    emb = tensor.switch(y[:, None] < 0,
                        tensor.alloc(0., 1, tparams['Wemb'].shape[1]),
                        tparams['Wemb'][y])
    proj = get_layer(options['encoder'])[1](tparams, emb, options,
                                            prefix='encoder',
                                            mask=None,
                                            one_step=True,
                                            init_state=init_state)
    next_state = proj[0]

    logit_lstm = get_layer('ff')[1](tparams, next_state, options,
                                    prefix='ff_logit_lstm', activ='linear')
    logit_prev = get_layer('ff')[1](tparams, emb, options,
                                    prefix='ff_logit_prev', activ='linear')
    logit = tensor.tanh(logit_lstm+logit_prev)
    logit = get_layer('ff')[1](tparams, logit, options,
                               prefix='ff_logit', activ='linear')
    next_probs = tensor.nnet.softmax(logit)

    # next word probability
    print 'Building f_next..',
    inps = [y, init_state]
    outs = [next_probs, next_state]
    f_next = theano.function(inps, outs, name='f_next', profile=profile)
    print 'Done'

    return f_next


# Scores a given sequence with the language model
def score_seq(seq, f_next, options, normalize):
    """Return the negative log-likelihood of ``seq`` under the language model.

    Args:
        seq: sequence of integer word indices to score.
        f_next: callable taking ``(previous_word, state)`` and returning
            ``(next_probs, next_state)``; ``next_probs[0, w]`` is read as the
            probability of word ``w``.
        options: model options; ``options['dim']`` gives the state width.
        normalize: if true, divide the total by ``len(seq)``.

    Returns:
        The summed (or length-averaged) negative log-probability; 0.0 for an
        empty sequence.
    """
    # -1 marks "no previous word" for the first step
    next_w = -1 * numpy.ones((1,)).astype('int64')
    next_state = numpy.zeros((1, options['dim'])).astype('float32')

    seq_len = len(seq)
    sample_score = 0.0
    for token in seq:
        next_p, next_state = f_next(next_w, next_state)[:2]

        # accumulate nll for each token
        sample_score -= numpy.log(next_p[0, token])

        # feed the token just scored as the previous word of the next step;
        # without this every step would be conditioned on the -1 marker
        next_w = numpy.array([token]).astype('int64')

    # skip the division for an empty sequence to avoid dividing by zero
    if normalize and seq_len > 0:
        sample_score /= seq_len

    return sample_score


# Linearly interpolate between two scores using beta
def shallow_fusion(score_lm, score_tm, beta, convex_comb):
    """Combine a language-model score with a translation-model score.

    With ``convex_comb`` the result is ``(1 - beta) * score_tm +
    beta * score_lm``; otherwise the translation score keeps full weight and
    ``beta * score_lm`` is added to it.
    """
    lm_term = (beta * score_lm)
    if not convex_comb:
        return score_tm + lm_term
    return (1 - beta) * score_tm + lm_term


def main(model, model_options, dictionary_lm, dictionary_tm,
         source, saveto, normalize=False, chr_level=False,
         beta=0.5, convex_comb=False):

    # load model options
    model_options = pkl.load(open(model_options))

    # reload parameters
    print 'Loading language model..',
    params = load_params(model)
    tparams = init_tparams(params)
    print 'Done'

    print 'Loading LM dictionary..',
    word_dict_lm = load_dictionary(dictionary_lm)
    print 'Done'

    print 'Loading TM dictionary..',
    word_dict_tm = load_dictionary(dictionary_tm)
    word_idict_tm = invert_dictionary(word_dict_tm)
    print 'Done'

    f_next = build_sampler(tparams, model_options)

    # Create a cross dictionary from tm to lm
    tm2lm_idx = {}
    for idx, word in word_idict_tm.items():
        tm2lm_idx[idx] = word_dict_lm.get(word, 1)
    tm2lm_idx[0] = 0  # <eos>
    tm2lm_idx[1] = 1  # UNK

    # Iterate over the n-best list generated by TM
    print 'Rescoring..',
    new_trans = []
    nbest_idx = 0
    with open(source, 'r') as f:
        scores_in_nbest = []
        trans_in_nbest = []
        for idx, line in enumerate(f):
            line_idx, trans, score_tm = line.strip().split('|||')
            if chr_level:
                words = list(trans.decode('utf-8').strip())
            else:
                words = line.strip().split()
            x = map(lambda w: word_dict_tm[w]
                    if w in word_dict_tm else 1, words)
            x = map(lambda ii: ii if ii < model_options['n_words'] else 1, x)
            x += [0]

            # Score the sequence with LM
            x_lm = [tm2lm_idx[xx] for xx in x]
            score_lm = score_seq(x_lm, f_next, model_options, normalize)

            # Take linear interpolation with Beta
            new_score = shallow_fusion(score_lm, float(score_tm),
                                       beta, convex_comb)
            if int(line_idx) > nbest_idx:
                new_trans.append(
                    trans_in_nbest[numpy.argmin(scores_in_nbest)])

                scores_in_nbest = []
                trans_in_nbest = []
                nbest_idx += 1
            else:
                scores_in_nbest.append(new_score)
                trans_in_nbest.append(trans)

    print 'Done'
    print 'Saving to %s' % saveto
    with open(saveto, 'w') as f:
        print >>f, '\n'.join(new_trans)
    print 'Done'
    return


if __name__ == "__main__":
    # Command-line entry point: parse the arguments and hand them to main().
    cli = argparse.ArgumentParser()
    cli.add_argument('-b', '--beta', type=float, default=1.,
                     help="Weight for language model score")

    # boolean switches, all off unless given
    for flag, description in (
            ('-n', "Normalize wrt sequence length"),
            ('-c', "Character level"),
            ('-x', "Take convex combination using beta")):
        cli.add_argument(flag, action="store_true", default=False,
                         help=description)

    # positional arguments, in the order they must appear
    cli.add_argument('model', type=str)
    cli.add_argument('model_options', type=str)
    cli.add_argument('dictionary_lm', type=str,
                     help='Dictionary of language model')
    cli.add_argument('dictionary_tm', type=str,
                     help='Target side dictionary of translation model')
    cli.add_argument('source', type=str)
    cli.add_argument('saveto', type=str)

    opts = cli.parse_args()

    main(opts.model, opts.model_options,
         opts.dictionary_lm, opts.dictionary_tm,
         opts.source, opts.saveto,
         normalize=opts.n, chr_level=opts.c, beta=opts.beta,
         convex_comb=opts.x)
